/*
 * Copyright (C) 2006-2021  Music Technology Group - Universitat Pompeu Fabra
 *
 * This file is part of Essentia
 *
 * Essentia is free software: you can redistribute it and/or modify it under
 * the terms of the GNU Affero General Public License as published by the Free
 * Software Foundation (FSF), either version 3 of the License, or (at your
 * option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
 * details.
 *
 * You should have received a copy of the Affero GNU General Public License
 * version 3 along with this program.  If not, see http://www.gnu.org/licenses/
 */

#include "pitchcrepe.h"

using namespace std;
using namespace essentia;
using namespace standard;

// Static metadata of the algorithm: its identifier, the category it is listed
// under and the user-facing documentation text. The description is a runtime
// string (wrapped by the DOC macro), so its wording is part of the program
// output and must be edited with the same care as code.
const char* PitchCREPE::name = "PitchCREPE";
const char* PitchCREPE::category = "Pitch";
const char* PitchCREPE::description = DOC(
  "This algorithm estimates pitch of monophonic audio signals using CREPE models.\n"
  "\n"
  "This algorithm is a wrapper to post-process the activations generated by TensorflowPredictCREPE. "
  "`time` contains the timestamps in which the pitch was estimated. `frequency` is the vector of "
  "pitch estimations in Hz. `confidence` expresses the confidence in the presence of pitch for each "
  "timestamp as value between 0 to 1. `activations` is a time by sigmoid activations matrix returned "
  " by the neural network.\n"
  "\n"
  "See TensorflowPredictCREPE for details about the rest of parameters.\n"
  "The recommended pipeline is as follows::\n"
  "\n"
  "  MonoLoader(sampleRate=16000) >> PitchCREPE()\n"
  "\n"
  "Notes:\n"
  "This algorithm does not make any check on the input model so it is "
  "the user's responsibility to make sure it is a valid one.\n"
  "The required sample rate of input signal is 16 KHz. "
  "Other sample rates will lead to an incorrect behavior.\n"
  "\n"
  "References:\n"
  "\n"
  "1. CREPE: A Convolutional Representation for Pitch Estimation. "
  "Jong Wook Kim, Justin Salamon, Peter Li, Juan Pablo Bello. "
  "Proceedings of the IEEE International Conference on Acoustics, Speech, and Signal "
  "Processing (ICASSP), 2018.\n"
  "\n"
  "2. Original models and code at https://github.com/marl/crepe/\n"
  "\n"
  "3. Supported models at https://essentia.upf.edu/models/\n\n");


// Configures the wrapped TensorflowPredictCREPE algorithm and (re)builds the
// bin-index -> cents lookup table used by toLocalAverageCents().
//
// The method may be invoked more than once on the same instance (every
// reconfiguration calls it again), so all state derived here is rebuilt from
// scratch instead of being accumulated.
void PitchCREPE::configure() {
  // Hand the model-related parameters over to the inner predictor.
  _tensorflowPredictCREPE->configure(INHERIT("graphFilename"),
                                     INHERIT("savedModel"),
                                     INHERIT("input"),
                                     INHERIT("output"),
                                     INHERIT("hopSize"),
                                     INHERIT("batchSize"));

  _hopSize = parameter("hopSize").toFloat();
  // _viterbi = parameter("viterbi").toBool();

  // Drop any table left over from a previous configure() call; without this
  // the push_back calls below would append a duplicate table on every
  // reconfiguration and the container would grow without bound.
  _centsMapping.clear();

  // The first (_nPitches - 1) entries are evenly spaced by _delta starting at
  // _shift; the last entry is pinned to (_end + _shift).
  for (int i = 0; i < _nPitches - 1; ++i) {
    _centsMapping.push_back(_delta * i + _shift);
  }
  _centsMapping.push_back(_end + _shift);
}


void PitchCREPE::compute() {
  const vector<Real>& audio = _audio.get();
  vector<Real>& time = _time.get();
  vector<Real>& frequency = _frequency.get();
  vector<Real>& confidence = _confidence.get();
  vector<vector<Real> >& activations = _activations.get();

  _tensorflowPredictCREPE->input("signal").set(audio);
  _tensorflowPredictCREPE->output("predictions").set(activations);
  _tensorflowPredictCREPE->compute();

  int timestamps = activations.size();

  time.resize(timestamps);
  frequency.resize(timestamps);
  confidence.resize(timestamps);
  
  frequency.assign(timestamps, 0.0);
  confidence.assign(timestamps, 0.0);

  // Activations to cents.
  vector<Real> cents = toLocalAverageCents(activations);

  for (int i = 0; i < timestamps; i++) {
    // Get the timestamp of each pitch prediction.
    time[i] = i * _hopSize / 1000.0;

    if (!isnan(cents[i])) {
      // Cents to frequencies.
      frequency[i] = 10.0 * pow(2.0, (cents[i] / 1200.0));

      // Get the confidences (higher activation per timestamp).
      // Our argmax implementation relies on max_element,
      // which returns the first element in case of multiple candidates.
      // While this may result in unintended prediction flaws, for now we
      // leave it like this so that we match the original implementation.
      confidence[i] = activations[i][argmax(activations[i])];
    }
  }
}

// Converts every activation frame into a pitch estimate in cents.
//
// For each frame the bin with the highest activation is located and a window
// spanning 4 bins below to 4 bins above it (truncated at the table limits) is
// taken. The result is the activation-weighted average of the cents values
// mapped to the bins in that window.
//
// activations: one vector of activations per frame.
// Returns one value per frame. When the weights in the window sum to zero the
// division is expected to produce NaN, which compute() filters out with isnan.
vector<Real> PitchCREPE::toLocalAverageCents(vector<vector<Real> > &activations) {
  const int timestamps = static_cast<int>(activations.size());
  vector<Real> cents(timestamps, 0);

  const int mappingSize = static_cast<int>(_centsMapping.size());

  // Scratch buffers reused across frames to avoid reallocating on every one.
  vector<Real> salience, toCents;
  for (int i = 0; i < timestamps; i++) {
    const vector<Real>& frame = activations[i];

    // The input model is not validated, so a frame is not guaranteed to hold
    // exactly _nPitches values. Restrict the window to the bins that exist in
    // both the frame and the cents table so the iterator ranges below can
    // never run past the end of either container.
    const int usableBins = min(_nPitches, min(static_cast<int>(frame.size()), mappingSize));

    const int center = argmax(frame);
    const int end = min(usableBins, center + 5);
    // Also keep start <= end: if the peak lies beyond the usable bins the
    // window collapses to an empty range instead of an inverted one.
    const int start = min(end, max(0, center - 4));

    salience.assign(frame.begin() + start, frame.begin() + end);
    toCents.assign(_centsMapping.begin() + start, _centsMapping.begin() + end);

    const Real productSum = dotProduct(salience, toCents);
    const Real weightSum = sum(salience);
    cents[i] = productSum / weightSum;
  }

  return cents;
}
